{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import xgboost\n",
    "import tensorflow as tf\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import tensorflow.contrib.eager as tfe\n",
    "import functools\n",
    "from sklearn.metrics import r2_score\n",
    "from src.utils import mape_score\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.linear_model import Ridge, Lasso\n",
    "from sklearn.svm import SVR\n",
    "import pandas as pd\n",
    "tfe.enable_eager_execution()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Load the pre-processed mall table; the first CSV column is used as the index.\n",
    "df = pd.read_csv('./data/sourceAB_shanghai_mall_processedB.csv', index_col=0)\n",
    "\n",
    "# Feature blocks are the columns whose names match the regexes for source A / source B.\n",
    "Xa = df.filter(regex='f_a.*', axis=1).values\n",
    "Xb = df.filter(regex='f_b.*', axis=1).values\n",
    "\n",
    "# Regression target is log(hotness). The float64 cast is applied before the log so that it\n",
    "# actually takes effect (the cast result used to be overwritten by the following line).\n",
    "y = np.log(np.asarray(df['hotness'].values, dtype=np.float64))\n",
    "\n",
    "# Standardise each feature block independently, then build the concatenated A+B block.\n",
    "# NOTE(review): the scalers are fit on the full dataset before cross-validation, so\n",
    "# validation-fold statistics leak into training; consider fitting per fold.\n",
    "Xa = StandardScaler().fit_transform(Xa)\n",
    "Xb = StandardScaler().fit_transform(Xb)\n",
    "Xab = np.concatenate((Xa, Xb), axis=1)\n",
    "# Optional target standardisation (disabled): y = (y - y.mean()) / y.std()\n",
    "\n",
    "# Lookup table so later cells can select a feature block by key (see feature_key).\n",
    "numpy_all = {\n",
    "    'feature_a': Xa,\n",
    "    'feature_b': Xb,\n",
    "    'feature_ab': Xab,\n",
    "    'target': y,\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Settings for the single-network run: batch_size and max_step are passed to kfold_tf,\n",
    "# feature_key selects which block of numpy_all the models are evaluated on.\n",
    "batch_size, max_step = 32, 3000\n",
    "feature_key = 'feature_ab'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from src.utils import validation_tf, kfold_tf, kfold_validation\n",
    "from src.tf_models import TypicalNetwork"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "use_feature: feature_ab\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-5-503cb52c85bc>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0muse_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTypicalNetwork\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0mmodel_para\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeature_key\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfeature_key\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhidden_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m40\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0msteps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mr2s\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moneminus_mapes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mkfold_tf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mXa\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mXb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mXab\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0muse_model\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_para\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_step\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      4\u001b[0m \u001b[0mdf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataFrame\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m'step'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0msteps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'r2'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mr2s\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'1-mape'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0moneminus_mapes\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/hzy/temp/multimodal_perception/src/utils.py\u001b[0m in \u001b[0;36mkfold_tf\u001b[0;34m(Xa, Xb, Xab, y, kfold_k, use_model, model_para, batch_size, max_step, verbose)\u001b[0m\n\u001b[1;32m    139\u001b[0m         \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset_default_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    140\u001b[0m         \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0muse_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mmodel_para\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 141\u001b[0;31m         \u001b[0msteps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mr2s\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moneminus_mapes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalidation_tf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_step\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnumpy_dataset_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    142\u001b[0m         \u001b[0mr2s_kfold\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mr2s\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    143\u001b[0m         \u001b[0moneminus_mape_kfold\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moneminus_mapes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/hzy/temp/multimodal_perception/src/utils.py\u001b[0m in \u001b[0;36mvalidation_tf\u001b[0;34m(model, dataset, batch_size, max_step, numpy_dataset_test, verbose)\u001b[0m\n\u001b[1;32m     95\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0mmax_step\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     96\u001b[0m             \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 97\u001b[0;31m         \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrads\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcal_value_gradients\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_batch\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'target'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     98\u001b[0m         \u001b[0mopt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply_gradients\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgrads\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     99\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;36m100\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/hzy/anaconda3/anaconda/lib/python3.6/site-packages/tensorflow/python/eager/backprop.py\u001b[0m in \u001b[0;36mgrad_fn\u001b[0;34m(*args)\u001b[0m\n\u001b[1;32m    353\u001b[0m                                            \u001b[0mpopped_tape\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    354\u001b[0m                                            \u001b[0mnest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mend_node\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 355\u001b[0;31m                                            sources)\n\u001b[0m\u001b[1;32m    356\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mend_node\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvariables\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    357\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/hzy/anaconda3/anaconda/lib/python3.6/site-packages/tensorflow/python/eager/imperative_grad.py\u001b[0m in \u001b[0;36mimperative_grad\u001b[0;34m(vspace, tape, target, sources, output_gradients)\u001b[0m\n\u001b[1;32m    198\u001b[0m         \u001b[0mout_gradients\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvspace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maggregate_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_gradients\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    199\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m     \u001b[0min_gradients\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mop_trace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_gradients\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    201\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mop_trace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minput_ids\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    202\u001b[0m       \u001b[0;32mif\u001b[0m \u001b[0min_gradients\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/hzy/anaconda3/anaconda/lib/python3.6/site-packages/tensorflow/python/eager/backprop.py\u001b[0m in \u001b[0;36mgrad_fn\u001b[0;34m(*orig_outputs)\u001b[0m\n\u001b[1;32m    278\u001b[0m     \u001b[0;34m\"\"\"Generated gradient function.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    279\u001b[0m     result = _magic_gradient_function(op_name, attrs, num_inputs,\n\u001b[0;32m--> 280\u001b[0;31m                                       op_inputs, op_outputs, orig_outputs)\n\u001b[0m\u001b[1;32m    281\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0m_tracing\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    282\u001b[0m       print(\"Gradient for\", (name if name else op_name), \"inputs\", op_inputs,\n",
      "\u001b[0;32m/Users/hzy/anaconda3/anaconda/lib/python3.6/site-packages/tensorflow/python/eager/backprop.py\u001b[0m in \u001b[0;36m_magic_gradient_function\u001b[0;34m(op_name, attr_tuple, num_inputs, inputs, outputs, out_grads)\u001b[0m\n\u001b[1;32m    107\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mnum_inputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    108\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 109\u001b[0;31m   \u001b[0;32mreturn\u001b[0m \u001b[0mgrad_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmock_op\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mout_grads\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    110\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    111\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/hzy/anaconda3/anaconda/lib/python3.6/site-packages/tensorflow/python/ops/math_grad.py\u001b[0m in \u001b[0;36m_RealDivGrad\u001b[0;34m(op, grad)\u001b[0m\n\u001b[1;32m    825\u001b[0m   return (array_ops.reshape(\n\u001b[1;32m    826\u001b[0m       \u001b[0mmath_ops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreduce_sum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmath_ops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrealdiv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 827\u001b[0;31m       \u001b[0msx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marray_ops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    828\u001b[0m           math_ops.reduce_sum(grad * math_ops.realdiv(math_ops.realdiv(-x, y), y),\n\u001b[1;32m    829\u001b[0m                               ry), sy))\n",
      "\u001b[0;32m/Users/hzy/anaconda3/anaconda/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py\u001b[0m in \u001b[0;36mreshape\u001b[0;34m(tensor, shape, name)\u001b[0m\n\u001b[1;32m   3985\u001b[0m     \u001b[0m_attrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m\"T\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_op\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_attr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"T\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"Tshape\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_op\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_attr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Tshape\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3986\u001b[0m   \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3987\u001b[0;31m     \u001b[0m_attr_T\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_execute\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margs_to_matching_eager\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_ctx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   3988\u001b[0m     \u001b[0m_attr_T\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_attr_T\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_datatype_enum\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3989\u001b[0m     \u001b[0m_attr_Tshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_execute\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margs_to_matching_eager\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_ctx\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0m_dtypes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mint32\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/hzy/anaconda3/anaconda/lib/python3.6/site-packages/tensorflow/python/eager/execute.py\u001b[0m in \u001b[0;36margs_to_matching_eager\u001b[0;34m(l, ctx, default_dtype)\u001b[0m\n\u001b[1;32m    171\u001b[0m   \u001b[0mEagerTensor\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mEagerTensor\u001b[0m  \u001b[0;31m# pylint: disable=invalid-name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    172\u001b[0m   \u001b[0;32mif\u001b[0m \u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mEagerTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0ml\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 173\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0ml\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ml\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    174\u001b[0m   \u001b[0;31m# TODO(josh11b): Could we do a better job if we also passed in the\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    175\u001b[0m   \u001b[0;31m# allowed dtypes when that was known?\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/hzy/anaconda3/anaconda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py\u001b[0m in \u001b[0;36mdtype\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    608\u001b[0m     \u001b[0mtape\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdelete_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    609\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 610\u001b[0;31m   \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    611\u001b[0m   \u001b[0;32mdef\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    612\u001b[0m     \u001b[0;31m# Note: using the intern table directly here as this is\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Cross-validate TypicalNetwork on the feature block selected by feature_key.\n",
    "use_model = TypicalNetwork\n",
    "model_para = {'feature_key': feature_key, 'hidden_size': 40}\n",
    "steps, r2s, oneminus_mapes = kfold_tf(\n",
    "    Xa, Xb, Xab, y,\n",
    "    kfold_k=4,\n",
    "    use_model=use_model,\n",
    "    model_para=model_para,\n",
    "    batch_size=batch_size,\n",
    "    max_step=max_step,\n",
    "    verbose=False,\n",
    ")\n",
    "# NOTE: this rebinds df (until now the raw CSV frame) to the per-step metrics table.\n",
    "df = pd.DataFrame({'step': steps, 'r2': r2s, '1-mape': oneminus_mapes})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Baseline regressors (scikit-learn, XGBoost)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.87563777309855984"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Baseline: ordinary least-squares linear regression on the selected feature block.\n",
    "# The positional 4 is presumably the fold count -- confirm against src.utils.kfold_validation.\n",
    "kfold_validation(numpy_all[feature_key], numpy_all['target'], 4, LinearRegression(), use_mape=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.87349493044999549"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Baseline: Lasso (L1-regularised linear regression, alpha=0.01) on the selected feature block.\n",
    "# The positional 4 is presumably the fold count -- confirm against src.utils.kfold_validation.\n",
    "kfold_validation(numpy_all[feature_key], numpy_all['target'], 4, Lasso(alpha=0.01), use_mape=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.87239746587019307"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Baseline: Ridge (L2-regularised linear regression, alpha=0.2) on the selected feature block.\n",
    "# The positional 4 is presumably the fold count -- confirm against src.utils.kfold_validation.\n",
    "kfold_validation(numpy_all[feature_key], numpy_all['target'], 4, Ridge(alpha=0.2), use_mape=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.88020918331539721"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Baseline: XGBoost regressor with its default hyper-parameters on the selected feature block.\n",
    "# The positional 4 is presumably the fold count -- confirm against src.utils.kfold_validation.\n",
    "kfold_validation(numpy_all[feature_key], numpy_all['target'], 4, xgboost.XGBRegressor(), use_mape=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.88127732859726104"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Baseline: support vector regression (C=2.0, default kernel) on the selected feature block.\n",
    "# The positional 4 is presumably the fold count -- confirm against src.utils.kfold_validation.\n",
    "kfold_validation(numpy_all[feature_key], numpy_all['target'], 4, SVR(C=2.0), use_mape=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Two networks (ProposedTwoLayers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from src.tf_models import ProposedTwoLayers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Settings for the ProposedTwoLayers run: same batch size, longer step budget (5000).\n",
    "batch_size, max_step = 32, 5000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "collapsed": true,
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Model class and constructor keyword arguments handed to kfold_tf in the next cell.\n",
    "use_model = ProposedTwoLayers\n",
    "model_para = {'hidden_a': 20, 'hidden_b': 20, 'use_dropout': True}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "use_feature: feature_ab\n",
      "fold: 1 / 4  r2: 0.575496 1-mape: 0.865433\n",
      "use_feature: feature_ab\n"
     ]
    }
   ],
   "source": [
    "# Cross-validate the model class / parameters currently bound to use_model and model_para.\n",
    "steps, r2s, oneminus_mapes = kfold_tf(\n",
    "    Xa, Xb, Xab, y,\n",
    "    kfold_k=4,\n",
    "    use_model=use_model,\n",
    "    model_para=model_para,\n",
    "    batch_size=batch_size,\n",
    "    max_step=max_step,\n",
    "    verbose=False,\n",
    ")\n",
    "# Per-step metrics table; rebinds df, which the ranking cell below reads.\n",
    "df = pd.DataFrame({'step': steps, 'r2': r2s, '1-mape': oneminus_mapes})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style>\n",
       "    .dataframe thead tr:only-child th {\n",
       "        text-align: right;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>1-mape</th>\n",
       "      <th>r2</th>\n",
       "      <th>step</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>40</th>\n",
       "      <td>0.874003</td>\n",
       "      <td>0.463630</td>\n",
       "      <td>4100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>0.873947</td>\n",
       "      <td>0.488478</td>\n",
       "      <td>1800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>0.873017</td>\n",
       "      <td>0.483799</td>\n",
       "      <td>2100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>44</th>\n",
       "      <td>0.872835</td>\n",
       "      <td>0.452362</td>\n",
       "      <td>4500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>0.871858</td>\n",
       "      <td>0.480349</td>\n",
       "      <td>3500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>47</th>\n",
       "      <td>0.871580</td>\n",
       "      <td>0.444040</td>\n",
       "      <td>4800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>0.871447</td>\n",
       "      <td>0.476700</td>\n",
       "      <td>2200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38</th>\n",
       "      <td>0.870389</td>\n",
       "      <td>0.445788</td>\n",
       "      <td>3900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>0.870136</td>\n",
       "      <td>0.470281</td>\n",
       "      <td>1900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>46</th>\n",
       "      <td>0.869319</td>\n",
       "      <td>0.435511</td>\n",
       "      <td>4700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>41</th>\n",
       "      <td>0.868697</td>\n",
       "      <td>0.464937</td>\n",
       "      <td>4200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50</th>\n",
       "      <td>0.868673</td>\n",
       "      <td>0.423369</td>\n",
       "      <td>5000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>49</th>\n",
       "      <td>0.868673</td>\n",
       "      <td>0.423369</td>\n",
       "      <td>5000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>0.868541</td>\n",
       "      <td>0.460255</td>\n",
       "      <td>2900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37</th>\n",
       "      <td>0.868487</td>\n",
       "      <td>0.471145</td>\n",
       "      <td>3800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>35</th>\n",
       "      <td>0.868377</td>\n",
       "      <td>0.460122</td>\n",
       "      <td>3600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>0.868006</td>\n",
       "      <td>0.462435</td>\n",
       "      <td>1500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>42</th>\n",
       "      <td>0.867806</td>\n",
       "      <td>0.456845</td>\n",
       "      <td>4300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36</th>\n",
       "      <td>0.867270</td>\n",
       "      <td>0.427169</td>\n",
       "      <td>3700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>48</th>\n",
       "      <td>0.867188</td>\n",
       "      <td>0.458979</td>\n",
       "      <td>4900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>0.867029</td>\n",
       "      <td>0.447451</td>\n",
       "      <td>2700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>0.866746</td>\n",
       "      <td>0.441671</td>\n",
       "      <td>900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>0.866587</td>\n",
       "      <td>0.469412</td>\n",
       "      <td>1400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>45</th>\n",
       "      <td>0.865915</td>\n",
       "      <td>0.430296</td>\n",
       "      <td>4600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>0.865825</td>\n",
       "      <td>0.442324</td>\n",
       "      <td>2500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>0.865808</td>\n",
       "      <td>0.481119</td>\n",
       "      <td>2300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>0.865637</td>\n",
       "      <td>0.479730</td>\n",
       "      <td>1300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>0.865492</td>\n",
       "      <td>0.464096</td>\n",
       "      <td>2400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>0.865104</td>\n",
       "      <td>0.416445</td>\n",
       "      <td>3000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>33</th>\n",
       "      <td>0.864536</td>\n",
       "      <td>0.443372</td>\n",
       "      <td>3400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>0.863944</td>\n",
       "      <td>0.408261</td>\n",
       "      <td>1000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>0.863824</td>\n",
       "      <td>0.457692</td>\n",
       "      <td>2800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>0.863208</td>\n",
       "      <td>0.446406</td>\n",
       "      <td>2000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>0.863170</td>\n",
       "      <td>0.436925</td>\n",
       "      <td>3300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>0.863047</td>\n",
       "      <td>0.420563</td>\n",
       "      <td>600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>0.862754</td>\n",
       "      <td>0.420830</td>\n",
       "      <td>1100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>0.862672</td>\n",
       "      <td>0.429154</td>\n",
       "      <td>1700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>0.862195</td>\n",
       "      <td>0.428311</td>\n",
       "      <td>700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>0.861190</td>\n",
       "      <td>0.443201</td>\n",
       "      <td>3100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>43</th>\n",
       "      <td>0.860763</td>\n",
       "      <td>0.426792</td>\n",
       "      <td>4400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>0.860157</td>\n",
       "      <td>0.416387</td>\n",
       "      <td>1600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>0.859180</td>\n",
       "      <td>0.413060</td>\n",
       "      <td>800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>0.858605</td>\n",
       "      <td>0.443871</td>\n",
       "      <td>1200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.857223</td>\n",
       "      <td>0.402336</td>\n",
       "      <td>500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.856677</td>\n",
       "      <td>0.389314</td>\n",
       "      <td>400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>0.856077</td>\n",
       "      <td>0.407695</td>\n",
       "      <td>2600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>0.852730</td>\n",
       "      <td>0.360912</td>\n",
       "      <td>3200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>39</th>\n",
       "      <td>0.845266</td>\n",
       "      <td>0.333785</td>\n",
       "      <td>4000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.841170</td>\n",
       "      <td>0.298772</td>\n",
       "      <td>300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.835092</td>\n",
       "      <td>0.300258</td>\n",
       "      <td>200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.804493</td>\n",
       "      <td>0.075339</td>\n",
       "      <td>100</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      1-mape        r2  step\n",
       "40  0.874003  0.463630  4100\n",
       "17  0.873947  0.488478  1800\n",
       "20  0.873017  0.483799  2100\n",
       "44  0.872835  0.452362  4500\n",
       "34  0.871858  0.480349  3500\n",
       "47  0.871580  0.444040  4800\n",
       "21  0.871447  0.476700  2200\n",
       "38  0.870389  0.445788  3900\n",
       "18  0.870136  0.470281  1900\n",
       "46  0.869319  0.435511  4700\n",
       "41  0.868697  0.464937  4200\n",
       "50  0.868673  0.423369  5000\n",
       "49  0.868673  0.423369  5000\n",
       "28  0.868541  0.460255  2900\n",
       "37  0.868487  0.471145  3800\n",
       "35  0.868377  0.460122  3600\n",
       "14  0.868006  0.462435  1500\n",
       "42  0.867806  0.456845  4300\n",
       "36  0.867270  0.427169  3700\n",
       "48  0.867188  0.458979  4900\n",
       "26  0.867029  0.447451  2700\n",
       "8   0.866746  0.441671   900\n",
       "13  0.866587  0.469412  1400\n",
       "45  0.865915  0.430296  4600\n",
       "24  0.865825  0.442324  2500\n",
       "22  0.865808  0.481119  2300\n",
       "12  0.865637  0.479730  1300\n",
       "23  0.865492  0.464096  2400\n",
       "29  0.865104  0.416445  3000\n",
       "33  0.864536  0.443372  3400\n",
       "9   0.863944  0.408261  1000\n",
       "27  0.863824  0.457692  2800\n",
       "19  0.863208  0.446406  2000\n",
       "32  0.863170  0.436925  3300\n",
       "5   0.863047  0.420563   600\n",
       "10  0.862754  0.420830  1100\n",
       "16  0.862672  0.429154  1700\n",
       "6   0.862195  0.428311   700\n",
       "30  0.861190  0.443201  3100\n",
       "43  0.860763  0.426792  4400\n",
       "15  0.860157  0.416387  1600\n",
       "7   0.859180  0.413060   800\n",
       "11  0.858605  0.443871  1200\n",
       "4   0.857223  0.402336   500\n",
       "3   0.856677  0.389314   400\n",
       "25  0.856077  0.407695  2600\n",
       "31  0.852730  0.360912  3200\n",
       "39  0.845266  0.333785  4000\n",
       "2   0.841170  0.298772   300\n",
       "1   0.835092  0.300258   200\n",
       "0   0.804493  0.075339   100"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rank the recorded evaluation steps by their 1-mape column, highest first.\n",
    "df.sort_values(by='1-mape', ascending=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
